feat: enable GRPO training with logprobs from offline trajectory data #467
Open · JRMeyer wants to merge 9 commits into OpenPipe:main from JRMeyer:fix/warn-engine-args-in-openai-server-config
Conversation
Add a runtime warning when users pass engine-initialization-only arguments (`max_logprobs`, `gpu_memory_utilization`, `tensor_parallel_size`, `max_model_len`) via `OpenAIServerConfig.engine_args`. These arguments are silently ignored because the vLLM engine is initialized by Unsloth before `OpenAIServerConfig` is applied. The warning directs users to `TrainableModel._internal_config` instead.
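A minimal sketch of the check (helper name and call site are hypothetical; the warning text paraphrases the commit message):

```python
import warnings

# Engine-initialization-only arguments: the vLLM engine is already running
# (initialized by Unsloth) before OpenAIServerConfig is applied, so these
# would be silently ignored if passed via engine_args.
ENGINE_INIT_ONLY_ARGS = {
    "max_logprobs",
    "gpu_memory_utilization",
    "tensor_parallel_size",
    "max_model_len",
}

def warn_on_engine_init_only_args(engine_args: dict) -> None:
    ignored = ENGINE_INIT_ONLY_ARGS.intersection(engine_args)
    if ignored:
        warnings.warn(
            f"engine_args {sorted(ignored)} are ignored because the vLLM engine "
            "is already initialized; set them via TrainableModel._internal_config "
            "instead.",
            UserWarning,
        )
```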
The `_internal_config` field was being lost when `TrainableModel` was deserialized from JSON (e.g., when sent from the client to the SkyPilot backend), because Pydantic ignores fields starting with an underscore during `model_validate()`. Added a `model_validator(mode="wrap")` that extracts `_internal_config` from the input data before validation and sets it on the model afterwards. This fixes the "Cannot request more than 0 logprobs" error when using `_internal_config.engine_args` with remote backends.
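A minimal sketch of the approach (not ART's exact code; field types simplified). Pydantic v2 drops underscore-prefixed keys during `model_validate()`, so the wrap validator pulls `_internal_config` out of the raw input and re-attaches it after normal validation:

```python
from typing import Any, Optional
from pydantic import BaseModel, PrivateAttr, model_validator

class TrainableModel(BaseModel):
    name: str
    _internal_config: Optional[dict] = PrivateAttr(default=None)

    @model_validator(mode="wrap")
    @classmethod
    def _preserve_internal_config(cls, data: Any, handler):
        internal = None
        if isinstance(data, dict) and "_internal_config" in data:
            data = dict(data)  # don't mutate the caller's payload
            internal = data.pop("_internal_config")
        model = handler(data)  # run normal Pydantic validation
        if internal is not None:
            model._internal_config = internal
        return model
```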
Adds three new metrics logged during training to help users verify that importance sampling is working correctly:
- `frac_old_logprobs_valid`: fraction of old logprobs that are not NaN
- `mean_importance_ratio`: mean π_new/π_old across assistant tokens
- `clip_fraction`: fraction of tokens where PPO clipping was triggered

These metrics help diagnose whether GRPO/PPO importance sampling is active or whether training has fallen back to vanilla REINFORCE (when all old logprobs are NaN).
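A minimal sketch (tensor shapes and the clip range are assumptions, not ART's exact code) of how the three metrics can be computed from per-token logprobs:

```python
import torch

def importance_sampling_metrics(
    new_logprobs: torch.Tensor,    # pi_new log-probs, shape [T]
    old_logprobs: torch.Tensor,    # pi_old log-probs, NaN where unavailable
    assistant_mask: torch.Tensor,  # bool mask selecting assistant tokens
    clip_eps: float = 0.2,         # PPO clip range (value assumed)
) -> dict[str, float]:
    valid = assistant_mask & ~torch.isnan(old_logprobs)
    frac_valid = valid.sum() / assistant_mask.sum().clamp(min=1)

    ratio = torch.exp(new_logprobs[valid] - old_logprobs[valid])  # pi_new / pi_old
    clipped = (ratio < 1 - clip_eps) | (ratio > 1 + clip_eps)
    return {
        "frac_old_logprobs_valid": frac_valid.item(),
        "mean_importance_ratio": ratio.mean().item() if valid.any() else float("nan"),
        "clip_fraction": clipped.float().mean().item() if valid.any() else 0.0,
    }
```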
Supports three formats in priority order:
1. New format: `token_ids` + `logprobs.values` (direct arrays)
2. Old format: `logprobs.content` with `token_id:XXX` parsing
3. No logprobs: re-tokenize with NaN logprobs

Fixes a token-count mismatch that caused `frac_old_logprobs_valid=0`.
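A minimal sketch of the priority order (helper name is hypothetical; field names are taken from the commit message):

```python
import math

def tokens_and_logprobs(message: dict, tokenizer) -> tuple[list[int], list[float]]:
    logprobs = message.get("logprobs") or {}

    # 1. New format: direct arrays of token ids and logprob values.
    if "token_ids" in message and "values" in logprobs:
        return list(message["token_ids"]), list(logprobs["values"])

    # 2. Old format: logprobs.content entries whose token field is "token_id:XXX".
    if logprobs.get("content"):
        ids = [int(item["token"].removeprefix("token_id:"))
               for item in logprobs["content"]]
        return ids, [item["logprob"] for item in logprobs["content"]]

    # 3. No logprobs: re-tokenize the text and mark every logprob as NaN.
    ids = tokenizer.encode(message.get("content") or "", add_special_tokens=False)
    return ids, [math.nan] * len(ids)
```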
Skip trajectories where the final assistant message is stripped by the chat template (e.g., when it contains only `<think>` content), causing `continue_final_message=True` to fail.
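A minimal sketch of the skip condition (helper name is hypothetical; this assumes transformers raises `ValueError` when `continue_final_message=True` cannot locate the final message in the rendered template):

```python
def should_skip_trajectory(tokenizer, messages: list[dict]) -> bool:
    try:
        tokenizer.apply_chat_template(
            messages, tokenize=False, continue_final_message=True
        )
    except ValueError:
        # The template stripped the final assistant message
        # (e.g. <think>-only content), so continuation would fail.
        return True
    return False
```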
Summary
This PR enables proper GRPO training with importance sampling when using offline trajectory data (e.g., from vLLM traces). It includes four complementary changes:
1. Extract logprobs from dict messages

Problem: ART's tokenizer only extracted logprobs from OpenAI `Choice` objects, but offline trajectory data often stores logprobs in plain Python dicts. This caused all dict-message logprobs to be set to NaN, making the importance ratio always 1.0 (effectively REINFORCE instead of GRPO).

Solution: Modified `tokenize.py` to also extract logprobs from dict messages that have the format `{"logprobs": {"content": [{"logprob": -0.5}, ...]}}`, as in the example below.
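An illustrative dict message in the shape described above (the format comes from this PR; the content values are made up):

```python
message = {
    "role": "assistant",
    "content": "The answer is 42.",
    "logprobs": {"content": [{"logprob": -0.5}, {"logprob": -0.1}]},
}
```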
2. Strip logprobs before RULER scoring

Problem: When trajectories contain verbose logprobs data, sending them to the RULER judge causes context-length errors.

Solution: Strip logprobs from trajectories before sending them to RULER using `strip_logprobs()`.
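A usage sketch (import paths are assumptions; `ruler_score_group` is ART's RULER scoring entry point, and the judge model name is illustrative):

```python
from art.rewards import ruler_score_group
from art.utils import strip_logprobs  # assumed location of the helper

async def score(group):
    # Drop verbose logprob payloads so the judge prompt stays within context limits.
    return await ruler_score_group(strip_logprobs(group), "openai/o3")
```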
3. Preserve `_internal_config.engine_args`

Problem: When using `TrainableModel._internal_config.engine_args` to configure vLLM engine settings (like `max_logprobs`), the configuration was silently lost when using the SkyPilot backend.

Solution: Add a `model_validator(mode="wrap")` to preserve `_internal_config` during Pydantic deserialization.
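A usage sketch of what the fix enables (config class names are assumed from ART's `art.dev` module and may differ; values are illustrative):

```python
import art

# With the fix, this configuration survives client-to-backend JSON serialization.
model = art.TrainableModel(
    name="offline-grpo",
    project="my-project",
    base_model="Qwen/Qwen2.5-7B-Instruct",
    _internal_config=art.dev.InternalModelConfig(
        engine_args=art.dev.EngineArgs(max_logprobs=64),
    ),
)
```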
4. Add importance sampling observability metrics

Problem: ART computes importance sampling ratios internally but doesn't expose them, making it impossible to verify whether importance sampling is actually working.

Solution: Add three new metrics logged during training:
- `frac_old_logprobs_valid`: fraction of old logprobs that are not NaN (0 = no importance sampling)
- `mean_importance_ratio`: mean π_new/π_old across assistant tokens (should vary around 1.0)
- `clip_fraction`: fraction of tokens where PPO clipping was triggered (>0 means off-policy correction is active)
Impact

Previously, offline trajectories produced all-NaN old logprobs, so the importance ratio π_new / π_old was effectively pinned at 1.0 and training degenerated to vanilla REINFORCE. With these changes, GRPO trains on the true ratio, and `frac_old_logprobs_valid`, `mean_importance_ratio`, and `clip_fraction` make that verifiable.
New metrics interpretation

- `frac_old_logprobs_valid`: near 1.0 means stored logprobs are being used; 0 means training has fallen back to vanilla REINFORCE
- `mean_importance_ratio`: should vary around 1.0 once importance sampling is active
- `clip_fraction`: values above 0 indicate PPO clipping is correcting off-policy drift

Test plan
- Verified the `max_logprobs` setting works with the SkyPilot backend
- `./scripts/run_checks.sh`: all checks pass